/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.operators.hash;

import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongComparator;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.runtime.TupleComparator;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.operators.testutils.UniformStringPairGenerator;
import org.apache.flink.runtime.operators.testutils.types.IntList;
import org.apache.flink.runtime.operators.testutils.types.IntListComparator;
import org.apache.flink.runtime.operators.testutils.types.IntListPairComparator;
import org.apache.flink.runtime.operators.testutils.types.IntListSerializer;
import org.apache.flink.runtime.operators.testutils.types.IntPair;
import org.apache.flink.runtime.operators.testutils.types.IntPairComparator;
import org.apache.flink.runtime.operators.testutils.types.IntPairListPairComparator;
import org.apache.flink.runtime.operators.testutils.types.IntPairPairComparator;
import org.apache.flink.runtime.operators.testutils.types.IntPairSerializer;
import org.apache.flink.runtime.operators.testutils.types.StringPair;
import org.apache.flink.runtime.operators.testutils.types.StringPairComparator;
import org.apache.flink.runtime.operators.testutils.types.StringPairPairComparator;
import org.apache.flink.runtime.operators.testutils.types.StringPairSerializer;
import org.apache.flink.util.MutableObjectIterator;

import org.junit.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import static org.junit.Assert.*;

/**
 * Shared test cases for {@link AbstractMutableHashTable} implementations.
 *
 * <p>Concrete subclasses provide the table under test through {@link #getHashTable}. Every test
 * builds a table on freshly allocated memory segments, exercises it, closes it, and finally checks
 * that the table hands back exactly as many segments as it was given.
 */
public abstract class MutableHashTableTestBase {

    /** Fixed seed so that the generated test data is the same on every run. */
    protected static final long RANDOM_SEED = 76518743207143L;

    /** Offset between key and value of the pairs built by {@link #getRandomizedIntPairs}. */
    private static final int KEY_VALUE_DIFF = 1021;

    /** Size handed to the segment factory for every memory segment allocated by the tests. */
    public final int PAGE_SIZE = 16 * 1024;

    public final TypeSerializer<IntPair> intPairSerializer = new IntPairSerializer();

    public final TypeComparator<IntPair> intPairComparator = new IntPairComparator();

    public final TypePairComparator<IntPair, IntPair> pairComparator = new IntPairPairComparator();

    /** Upper bound (inclusive) on the length of the arrays built by {@link #getRandomizedIntLists}. */
    private static final int MAX_LIST_SIZE = 8;

    public final TypeSerializer<IntList> serializerV = new IntListSerializer();

    public final TypeComparator<IntList> comparatorV = new IntListComparator();

    public final TypePairComparator<IntList, IntList> pairComparatorV = new IntListPairComparator();

    public final TypePairComparator<IntPair, IntList> pairComparatorPL =
            new IntPairListPairComparator();

    // Not referenced in this class; presumably used by subclasses.
    @SuppressWarnings("unchecked")
    protected TupleSerializer<Tuple2<Long, String>> tuple2LongStringSerializer =
            new TupleSerializer<>(
                    (Class<Tuple2<Long, String>>) (Class<?>) Tuple2.class,
                    new TypeSerializer<?>[] {LongSerializer.INSTANCE, StringSerializer.INSTANCE});

    // Not referenced in this class; presumably used by subclasses.
    protected TupleComparator<Tuple2<Long, String>> tuple2LongStringComparator =
            new TupleComparator<>(
                    new int[] {0},
                    new TypeComparator<?>[] {new LongComparator(true)},
                    new TypeSerializer<?>[] {LongSerializer.INSTANCE});

    /** Per-record budget used when sizing the memory of the {@link IntList} tests. */
    public final int SIZE = 75;

    /** Number of {@link IntPair} / {@link StringPair} records used by the tests. */
    public final int NUM_PAIRS = 100000;

    /** Number of {@link IntList} records used by the tests. */
    public final int NUM_LISTS = 100000;

    // Not referenced in this class; presumably used by subclasses.
    protected final int ADDITIONAL_MEM = 100;

    /** Number of full overwrite rounds performed by {@link #testRepeatedBuildAndRetrieve()}. */
    private static final int NUM_REWRITES = 10;

    public final TypeSerializer<StringPair> serializerS = new StringPairSerializer();

    public final TypeComparator<StringPair> comparatorS = new StringPairComparator();

    private final TypePairComparator<StringPair, StringPair> pairComparatorS =
            new StringPairPairComparator();

    /**
     * Creates the hash table implementation under test.
     *
     * @param serializer serializer for the records stored in the table
     * @param comparator comparator for the records stored in the table
     * @param memory memory segments the table may use
     * @param <T> type of the records stored in the table
     * @return a new, not yet opened table
     */
    protected abstract <T> AbstractMutableHashTable<T> getHashTable(
            TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory);

    /** Two probers requested from the same table must be distinct objects. */
    @Test
    public void testDifferentProbers() {
        final int NUM_MEM_PAGES = 32 * NUM_PAIRS / PAGE_SIZE;
        AbstractMutableHashTable<IntPair> table =
                getHashTable(intPairSerializer, intPairComparator, getMemory(NUM_MEM_PAGES));

        AbstractHashTableProber<IntPair, IntPair> prober1 =
                table.getProber(intPairComparator, pairComparator);
        AbstractHashTableProber<IntPair, IntPair> prober2 =
                table.getProber(intPairComparator, pairComparator);

        assertFalse("getProber returned the same instance twice", prober1 == prober2);

        table.close(); // (This also tests calling close without calling open first.)
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Inserts fixed-length records and checks that each one can be probed afterwards. */
    @Test
    public void testBuildAndRetrieve() throws Exception {
        final int NUM_MEM_PAGES = 32 * NUM_PAIRS / PAGE_SIZE;
        AbstractMutableHashTable<IntPair> table =
                getHashTable(intPairSerializer, intPairComparator, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntPair[] pairs = getRandomizedIntPairs(NUM_PAIRS, rnd);

        table.open();
        insertAll(table, pairs);

        AbstractHashTableProber<IntPair, IntPair> prober =
                table.getProber(intPairComparator, pairComparator);
        IntPair target = new IntPair();

        for (int i = 0; i < pairs.length; i++) {
            assertNotNull("No match for record at index " + i, prober.getMatchFor(pairs[i], target));
            assertEquals(pairs[i].getValue(), target.getValue());
        }

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** The entry iterator must return records whose keys add up to the sum of inserted keys. */
    @Test
    public void testEntryIterator() throws Exception {
        final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd);

        table.open();

        // The keys are 0..NUM_LISTS-1, so their sum does not fit into an int; accumulate in long
        // to compare the true sums rather than two values that wrapped around.
        long expectedKeySum = 0;
        for (IntList list : lists) {
            table.insert(list);
            expectedKeySum += list.getKey();
        }

        MutableObjectIterator<IntList> iter = table.getEntryIterator();
        IntList target = new IntList();

        long actualKeySum = 0;
        while ((target = iter.next(target)) != null) {
            actualKeySum += target.getKey();
        }
        table.close();

        assertEquals("Key sum of iterated entries", expectedKeySum, actualKeySum);
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Probes one table alternately through two probers with different probe-side types. */
    @Test
    public void testMultipleProbers() throws Exception {
        final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd);
        final IntPair[] pairs = getRandomizedIntPairs(NUM_LISTS, rnd);

        table.open();
        insertAll(table, lists);

        AbstractHashTableProber<IntList, IntList> listProber =
                table.getProber(comparatorV, pairComparatorV);

        AbstractHashTableProber<IntPair, IntList> pairProber =
                table.getProber(intPairComparator, pairComparatorPL);

        IntList target = new IntList();
        for (int i = 0; i < lists.length; i++) {
            assertNotNull(
                    "No pair-prober match at index " + i, pairProber.getMatchFor(pairs[i], target));
            assertNotNull(
                    "No list-prober match at index " + i, listProber.getMatchFor(lists[i], target));
            // target now holds the result of the list probe, which was issued last.
            assertArrayEquals(lists[i].getValue(), target.getValue());
        }
        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Inserts variable-length records, then replaces every one of them once. */
    @Test
    public void testVariableLengthBuildAndRetrieve() throws Exception {
        final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd);

        table.open();
        insertAll(table, lists);

        AbstractHashTableProber<IntList, IntList> prober =
                table.getProber(comparatorV, pairComparatorV);
        assertAllRetrievable(prober, lists);

        // test replacing
        final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd);
        replaceAll(table, overwriteLists);
        assertAllRetrievable(prober, overwriteLists);

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Inserts variable-length records, then replaces all but every 100th of them. */
    @Test
    public void testVariableLengthBuildAndRetrieveMajorityUpdated() throws Exception {
        final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd);

        table.open();
        insertAll(table, lists);

        AbstractHashTableProber<IntList, IntList> prober =
                table.getProber(comparatorV, pairComparatorV);
        assertAllRetrievable(prober, lists);

        final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd);

        // test replacing; "lists" is kept in sync so that it mirrors the expected table content
        for (int i = 0; i < NUM_LISTS; i++) {
            if (i % 100 != 0) {
                table.insertOrReplaceRecord(overwriteLists[i]);
                lists[i] = overwriteLists[i];
            }
        }

        assertAllRetrievable(prober, lists);

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Inserts variable-length records, then replaces only every 100th of them. */
    @Test
    public void testVariableLengthBuildAndRetrieveMinorityUpdated() throws Exception {
        // This test deliberately works on fewer records than the NUM_LISTS field prescribes.
        final int numLists = 20000;
        final int stepSize = 100;
        final int NUM_MEM_PAGES = SIZE * numLists / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(numLists, rnd);

        table.open();
        insertAll(table, lists);

        AbstractHashTableProber<IntList, IntList> prober =
                table.getProber(comparatorV, pairComparatorV);
        assertAllRetrievable(prober, lists);

        final IntList[] overwriteLists = getRandomizedIntLists(numLists / stepSize, rnd);

        // test replacing
        for (int i = 0; i < numLists; i += stepSize) {
            IntList replacement = overwriteLists[i / stepSize];
            // The generator assigned key i / stepSize; scale it so that it addresses record i.
            replacement.setKey(replacement.getKey() * stepSize);
            table.insertOrReplaceRecord(replacement);
            lists[i] = replacement;
        }

        assertAllRetrievable(prober, lists);

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Inserts variable-length records, then replaces all of them {@code NUM_REWRITES} times. */
    @Test
    public void testRepeatedBuildAndRetrieve() throws Exception {
        final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd);

        table.open();
        insertAll(table, lists);

        AbstractHashTableProber<IntList, IntList> prober =
                table.getProber(comparatorV, pairComparatorV);
        assertAllRetrievable(prober, lists);

        for (int k = 0; k < NUM_REWRITES; k++) {
            // test replacing
            final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd);
            replaceAll(table, overwriteLists);
            assertAllRetrievable(prober, overwriteLists);
        }

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Replaces every record through the prober right after the prober has matched it. */
    @Test
    public void testProberUpdate() throws Exception {
        final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE;
        AbstractMutableHashTable<IntList> table =
                getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES));

        final Random rnd = new Random(RANDOM_SEED);
        final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd);

        table.open();
        insertAll(table, lists);

        final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd);

        AbstractHashTableProber<IntList, IntList> prober =
                table.getProber(comparatorV, pairComparatorV);
        IntList target = new IntList();

        for (int i = 0; i < lists.length; i++) {
            assertNotNull("" + i, prober.getMatchFor(lists[i], target));
            assertArrayEquals(lists[i].getValue(), target.getValue());
            prober.updateMatch(overwriteLists[i]);
        }

        assertAllRetrievable(prober, overwriteLists);

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /** Builds, probes, replaces and probes again with string-valued records. */
    @Test
    public void testVariableLengthStringBuildAndRetrieve() throws IOException {
        final int NUM_MEM_PAGES = 40 * NUM_PAIRS / PAGE_SIZE;
        AbstractMutableHashTable<StringPair> table =
                getHashTable(serializerS, comparatorS, getMemory(NUM_MEM_PAGES));

        // Four identically configured generators, one per pass over the data.
        MutableObjectIterator<StringPair> buildInput =
                new UniformStringPairGenerator(NUM_PAIRS, 1, false);

        MutableObjectIterator<StringPair> probeTester =
                new UniformStringPairGenerator(NUM_PAIRS, 1, false);

        MutableObjectIterator<StringPair> updater =
                new UniformStringPairGenerator(NUM_PAIRS, 1, false);

        MutableObjectIterator<StringPair> updateTester =
                new UniformStringPairGenerator(NUM_PAIRS, 1, false);

        table.open();

        StringPair target = new StringPair();
        while (buildInput.next(target) != null) {
            table.insert(target);
        }

        AbstractHashTableProber<StringPair, StringPair> prober =
                table.getProber(comparatorS, pairComparatorS);
        StringPair temp = new StringPair();
        while (probeTester.next(target) != null) {
            assertNotNull("" + target.getKey(), prober.getMatchFor(target, temp));
            assertEquals(temp.getValue(), target.getValue());
        }

        while (updater.next(target) != null) {
            // NOTE(review): this writes the value back to itself, so the replacement record
            // appears to equal the stored one - confirm whether a modified value was intended.
            target.setValue(target.getValue());
            table.insertOrReplaceRecord(target);
        }

        while (updateTester.next(target) != null) {
            assertNotNull("" + target.getKey(), prober.getMatchFor(target, temp));
            assertEquals(target.getValue(), temp.getValue());
        }

        table.close();
        assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size());
    }

    /**
     * Inserts all given records into the table, in array order.
     *
     * @param table the opened table to fill
     * @param records the records to insert
     * @throws IOException if the table fails to insert a record
     */
    private static <T> void insertAll(AbstractMutableHashTable<T> table, T[] records)
            throws IOException {
        for (T record : records) {
            table.insert(record);
        }
    }

    /**
     * Passes all given records to {@code insertOrReplaceRecord}, in array order.
     *
     * @param table the opened table to update
     * @param records the records to insert or replace
     * @throws IOException if the table fails to store a record
     */
    private static <T> void replaceAll(AbstractMutableHashTable<T> table, T[] records)
            throws IOException {
        for (T record : records) {
            table.insertOrReplaceRecord(record);
        }
    }

    /**
     * Probes the table with each expected record and asserts that a match is reported and that
     * the matched value array equals the expected one.
     *
     * <p>The array comparison is made against the reuse object handed to the prober.
     *
     * @param prober prober of the table under test
     * @param expected records that must currently be stored in the table
     */
    private static void assertAllRetrievable(
            AbstractHashTableProber<IntList, IntList> prober, IntList[] expected) {
        final IntList target = new IntList();
        for (int i = 0; i < expected.length; i++) {
            assertNotNull("No match for record at index " + i, prober.getMatchFor(expected[i], target));
            assertArrayEquals(
                    "Wrong value for record at index " + i,
                    expected[i].getValue(),
                    target.getValue());
        }
    }

    /**
     * Creates {@code num} pairs with the keys {@code 0..num-1}, each with the value {@code key +
     * KEY_VALUE_DIFF}, and shuffles them by performing {@code 2 * num} random swaps.
     *
     * @param num number of pairs to create
     * @param rnd source of randomness for the shuffle
     * @return the shuffled pairs
     */
    protected static IntPair[] getRandomizedIntPairs(int num, Random rnd) {
        IntPair[] pairs = new IntPair[num];

        // create all the pairs, dense
        for (int i = 0; i < num; i++) {
            pairs[i] = new IntPair(i, i + KEY_VALUE_DIFF);
        }

        // randomly swap them
        for (int i = 0; i < 2 * num; i++) {
            int pos1 = rnd.nextInt(num);
            int pos2 = rnd.nextInt(num);

            IntPair tmp = pairs[pos1];
            pairs[pos1] = pairs[pos2];
            pairs[pos2] = tmp;
        }

        return pairs;
    }

    /**
     * Creates {@code num} lists with the keys {@code 0..num-1} in ascending order. Each list
     * carries between 1 and {@code MAX_LIST_SIZE} random non-positive integers.
     *
     * @param num number of lists to create
     * @param rnd source of randomness for the list lengths and contents
     * @return the generated lists
     */
    protected static IntList[] getRandomizedIntLists(int num, Random rnd) {
        IntList[] lists = new IntList[num];
        for (int i = 0; i < num; i++) {
            int[] value = new int[rnd.nextInt(MAX_LIST_SIZE) + 1];
            for (int j = 0; j < value.length; j++) {
                value[j] = -rnd.nextInt(Integer.MAX_VALUE);
            }
            lists[i] = new IntList(i, value);
        }

        return lists;
    }

    /**
     * Allocates {@code numPages} unpooled memory segments of {@code PAGE_SIZE} each.
     *
     * @param numPages number of segments to allocate
     * @return a new list holding the segments
     */
    public List<MemorySegment> getMemory(int numPages) {
        return getMemory(numPages, PAGE_SIZE);
    }

    /**
     * Allocates unpooled memory segments.
     *
     * @param numSegments number of segments to allocate
     * @param segmentSize size passed to the segment factory for each segment
     * @return a new list holding the segments
     */
    private static List<MemorySegment> getMemory(int numSegments, int segmentSize) {
        List<MemorySegment> list = new ArrayList<>(numSegments);
        for (int i = 0; i < numSegments; i++) {
            list.add(MemorySegmentFactory.allocateUnpooledSegment(segmentSize));
        }
        return list;
    }
}
